"""Code for the Fig. 3a plot.

Created by Yuechuan Lin 
03-20-2022 
Cornell University
"""
import os
from PyOCT import misc 
from h5py._hl.files import File
from matplotlib.pyplot import figure, step 
import numpy as np 
import matplotlib.pyplot as plt 
import matplotlib 
from matplotlib import cm 
import matplotlib.ticker as tck
import matplotlib.ticker 
from matplotlib import gridspec
import re 
import scipy
import h5py 
import time
from scipy import ndimage
from progress.bar import Bar
from scipy.signal.signaltools import resample_poly
import matplotlib.patches as patches
from mayavi import mlab
import mayavi as myv 
import pims
import trackpy as tp 
from tvtk.util.ctf import ColorTransferFunction
from tvtk.util import ctf
from tvtk.api import tvtk
import moviepy.editor as mpy
from scipy.io import loadmat
from mpl_toolkits.axes_grid1.inset_locator import inset_axes
from mpl_toolkits.mplot3d import Axes3D
from matplotlib import cm 
from matplotlib.colors import LightSource
from mpl_toolkits.mplot3d import Axes3D
from mpl_toolkits.mplot3d.art3d import Poly3DCollection
from matplotlib.ticker import (MultipleLocator, FormatStrFormatter,
                               AutoMinorLocator)
from mpl_toolkits.axes_grid1.inset_locator import inset_axes
import cv2
from mlabtex import mlabtex 
from tvtk.pyface import light_manager
from traits.api import HasPrivateTraits, HasTraits, Any, Int, \
     Property, Instance, Event, Range, Bool, Trait, Str
# Global matplotlib text settings for the figure: Helvetica (sans-serif) at
# 8 pt, no LaTeX text rendering, and Type 42 (TrueType) font embedding for
# PDF and PostScript output.
font = dict(weight='normal', size=8)
matplotlib.rcParams.update({
    'font.family': 'sans-serif',
    'font.sans-serif': ['Helvetica'],
})
matplotlib.rc('font', **font)
matplotlib.rc('text', usetex=False)
matplotlib.rcParams.update({'pdf.fonttype': 42, 'ps.fonttype': 42})

class MplColorHelper:
    """Map scalar values to colors through a named matplotlib colormap.

    Values are linearly normalized so that ``start_val`` maps to the low end
    of the colormap and ``stop_val`` to the high end.
    """

    def __init__(self, cmap_name, start_val, stop_val):
        colormap = plt.get_cmap(cmap_name)
        normalizer = matplotlib.colors.Normalize(vmin=start_val, vmax=stop_val)
        self.cmap_name = cmap_name
        self.cmap = colormap
        self.norm = normalizer
        self.scalarMap = cm.ScalarMappable(norm=normalizer, cmap=colormap)
        print(f"color clim is {self.scalarMap.get_clim()}")

    def get_rgb(self, val, alpha):
        """Return the color of ``val`` as 8-bit RGBA (``bytes=True``).

        ``val`` may be anything ``ScalarMappable.to_rgba`` accepts (a scalar
        or an array of values); ``alpha`` overrides the alpha channel.
        """
        return self.scalarMap.to_rgba(val, alpha=alpha, bytes=True)


def ChangeVolColormap(vol, cmapName, vmin, vmax, alpha):
    """
    Replace the colors of an mlab volume rendering with a matplotlib colormap.

    A new tvtk ColorTransferFunction is built from 1024 evenly spaced samples
    of ``cmapName`` over [vmin, vmax]; NaN voxels get color black with
    nan_opacity 1.0. The volume's current transfer functions are saved, the
    opacity of the second opacity point is set to ``alpha``, they are loaded
    back, and finally the new color function is installed.

    input:
    : vol: object returned by mlab.pipeline.volume; modified in place
    : cmapName: matplotlib colormap name, e.g. 'jet' (required, no default)
    : vmin, vmax: data values mapped to the two ends of the colormap
    : alpha: opacity written into the second opacity-function point; it is
             also passed to the colormap lookup, whose alpha channel is unused
    output:
    : the same ``vol`` object
    """
    from tvtk.util.ctf import ColorTransferFunction

    helper = MplColorHelper(cmapName, start_val=vmin, stop_val=vmax)
    new_ctf = ColorTransferFunction()
    samples = np.linspace(vmin, vmax, 1024)
    print("Num of colormap values is {}".format(np.size(samples)))

    def to_unit_rgb(value):
        # get_rgb returns 8-bit RGBA; tvtk expects r, g, b as floats in [0, 1]
        rgba8 = helper.get_rgb(value, alpha)
        return rgba8[0] / 255, rgba8[1] / 255, rgba8[2] / 255

    for value in samples:
        red, green, blue = to_unit_rgb(value)
        new_ctf.add_rgb_point(value, red, green, blue)

    new_ctf.nan_color = (0.0, 0.0, 0.0)
    new_ctf.nan_opacity = 1.0
    # Out-of-range colors are filled in, but their use is switched off.
    new_ctf.use_above_range_color = False
    new_ctf.use_below_range_color = False
    new_ctf.above_range_color = to_unit_rgb(samples[-1])
    new_ctf.below_range_color = to_unit_rgb(samples[0])

    # Round-trip the volume's transfer functions so the opacity can be edited.
    # NOTE(review): index [1] presumably targets the upper point of mayavi's
    # default two-point opacity function -- confirm it always has >= 2 points.
    saved = ctf.save_ctfs(vol._volume_property)
    saved['alpha'][1][1] = alpha
    ctf.load_ctfs(saved, vol._volume_property)
    # Install the new colors only after load_ctfs, which presumably also
    # reloads the saved color function.
    vol._volume_property.set_color(new_ctf)
    vol._ctf = new_ctf
    vol.update_ctf = True

    return vol


def LoadData(dataPath, cellPath, xOffset=0, verbose=True, depthRange=50,
             minGre=30, minGim=0):
    """
    Load bead-wise results and the cell-body point list for plotting.

    Returned coordinates follow the plot geometry: plot-y is data z
    (column 0), plot-x is data x (column 1, shifted by ``xOffset``) and
    plot-z is data y (column 2).

    input:
    : dataPath: HDF5 file read with h5py; dataset "results_beadwise" is
      transposed so rows are beads and columns are bead-wise fields
      (0 = z, 1 = x, 2 = y, 14 = G_re, 15 = G_im)
    : cellPath: .mat file read with scipy.io.loadmat; variable "cell" has one
      row per point, columns [0, 1, 2 coordinates, 3 std, 4 intensity]
    : xOffset: subtracted from the x coordinate of beads and cell points
    : verbose: print bead counts before/after filtering and the cell point count
    : depthRange: keep beads with -depthRange < column 0 < depthRange
    : minGre, minGim: keep beads with G_re > minGre and G_im > minGim
      (defaults: G_re > 30 Pa, G_im > 0)
    output:
    : y, x, z: coordinates of the kept beads (1D)
    : Gre: G_re of the kept beads
    : R: G_im / G_re of the kept beads
    : yc, xc, zc, intc: cell-body coordinates and intensity, truncated to the
      largest multiple of 3 points and reshaped to (n_points // 3, 3)
    """
    # Context manager so the HDF5 handle is closed once the array is read.
    with h5py.File(dataPath, "r") as fid_beads:
        results = np.transpose(fid_beads["results_beadwise"][()])  # (N, 16)
    if verbose:
        print("Number of beads before mask {}".format(np.shape(results)[0]))
    results = results[~np.any(np.isnan(results), axis=1)]  # drop beads with any NaN field

    # Vectorized threshold. A boolean mask keeps the 2D shape even when no
    # bead passes, so the column slicing below yields empty arrays instead
    # of raising IndexError.
    depth = results[:, 0]
    keep = ((depth > -depthRange) & (depth < depthRange)
            & (results[:, 14] > minGre) & (results[:, 15] > minGim))
    results = results[keep]
    if verbose:
        print("Number of beads after mask {}".format(np.shape(results)[0]))

    y = results[:, 0]
    x = results[:, 1] - xOffset
    z = results[:, 2]
    Gre = results[:, 14]
    Gim = results[:, 15]
    R = Gim / Gre  # Gre > minGre, so no division by zero when minGre >= 0

    cellBody = loadmat(cellPath)["cell"]
    yc = cellBody[:, 0]
    xc = cellBody[:, 1] - xOffset
    zc = cellBody[:, 2]
    intc = cellBody[:, 4]

    # Keep the largest multiple of 3 points so each column reshapes to (n, 3).
    n_rows = np.size(yc) // 3
    n_used = 3 * n_rows
    xc = np.reshape(xc[:n_used], (n_rows, 3))
    yc = np.reshape(yc[:n_used], (n_rows, 3))
    zc = np.reshape(zc[:n_used], (n_rows, 3))
    intc = np.reshape(intc[:n_used], (n_rows, 3))

    if verbose:
        print("# of Coordinates found in cell body is {}".format(np.shape(cellBody)[0]))

    return y, x, z, Gre, R, yc, xc, zc, intc





if __name__ == "__main__":
    # --- Inputs ---------------------------------------------------------------
    # Bead-wise result columns, one row per bead:
    # [0z 1x 2y 3snr 4ampRatio 5stdA 6stdphi 7Atot 8phitot 9APT 10phiPT
    #  11Amech 12phimech 13Frad 14GRe 15GIm]
    root_path = "./Fig3_data/earlyCell4/eCell4"
    beads_fname = "BeadsRes_210107.mat"
    cell_fname = "CellBody_210107.mat"
    # Light-sheet center offset from the beginning of the FOV (um): CellIII 52, CellII 57.
    xoffset = 52
    # True: 1024-px render intended for saving a .tif; False: 512 px (use when writing a GIF).
    saveFig = False
    fpixSize = 1024 if saveFig else 512

    # NOTE(review): the module-level code after this guard also uses the names
    # bound here (x, y, z, Gre, root_path, cell_fname, xoffset, saveFig,
    # fpixSize), so the file only works when executed as a script.
    y, x, z, Gre, R, yc, xc, zc, intc = LoadData(
        f"{root_path}/{beads_fname}", f"{root_path}/{cell_fname}", xOffset=xoffset
    )
    # Rescale the cell-body intensities to [0, 1].
    intc = (intc - np.min(intc)) / (np.max(intc) - np.min(intc))

    # --- Preview of the cell body; mlab.show() at the end starts the GUI loop ---
    f = mlab.figure(bgcolor=(0, 0, 0), size=(fpixSize, fpixSize))
    cellVolsrc = myv.tools.pipeline.scalar_field(xc, yc, zc, intc)
    vol = ChangeVolColormap(
        mlab.pipeline.volume(cellVolsrc, figure=f),
        cmapName="jet", vmin=np.amin(intc), vmax=np.amax(intc), alpha=1.0,
    )
    vol.update_pipeline()

    # Axis labels name the data axes: plot-y shows data Z, plot-z shows data Y.
    ax = mlab.axes(
        nb_labels=20,
        ranges=[bound(arr) for arr in (x, y, z) for bound in (np.amin, np.amax)],
    )
    ax.axes.label_format = "%.0f"
    ax.axes.x_label, ax.axes.y_label, ax.axes.z_label = "X", "Z", "Y"
    ax.axes.visibility = 1 if saveFig else 0
    oax = mlab.orientation_axes(xlabel="x", ylabel="z", zlabel="y")
    oax._text_property.use_tight_bounding_box = True
    oax.update_pipeline()
    mlab.show()



# NOTE(review): this module-level section uses names bound only inside the
# ``if __name__ == "__main__":`` block above (x, y, z, Gre, root_path,
# cell_fname, xoffset, ...). It runs when the file is executed as a script,
# but importing the module raises NameError here.

# Integer bead field of view per plot axis. dtype=int truncates toward zero,
# the same rounding applied to the cell-body voxel coordinates below.
xFOV = np.asarray([np.amin(x), np.amax(x)], dtype=int)
yFOV = np.asarray([np.amin(y), np.amax(y)], dtype=int)
zFOV = np.asarray([np.amin(z), np.amax(z)], dtype=int)

print("xFOV is {} to {}".format(xFOV[0], xFOV[1]))
print("yFOV is {} to {}".format(yFOV[0], yFOV[1]))
print("zFOV is {} to {}".format(zFOV[0], zFOV[1]))

# Reload the raw cell-body point list (LoadData returns it reshaped to (n, 3)).
# Columns by plot axis: 0 y, 1 x, 2 z, 3 std, 4 intensity.
fid_cell = loadmat(root_path + "/" + cell_fname)
cellbody = fid_cell['cell']
print(np.shape(cellbody))
yc = cellbody[:, 0].astype(int)
xc = (cellbody[:, 1] - xoffset).astype(int)
zc = cellbody[:, 2].astype(int)

print("*****************")
print(np.amin(zc))
print(np.amax(zc))
print("Gre max is {} Pa".format(np.amax(Gre)))
print("Gre min is {} Pa".format(np.amin(Gre)))
print("*****************")

stdc = cellbody[:, 3]  # std
intc = cellbody[:, 4]  # intensity

print("z range of cell body is {} to {}".format(np.amin(zc), np.amax(zc)))
print("x range of cell body is {} to {}".format(np.amin(xc), np.amax(xc)))
print("y range of cell body is {} to {}".format(np.amin(yc), np.amax(yc)))

# One grid node per integer coordinate across each FOV.
xgrid = np.linspace(int(np.amin(xFOV)), int(np.amax(xFOV)), int(np.amax(xFOV) - np.amin(xFOV) + 1))
ygrid = np.linspace(int(np.amin(yFOV)), int(np.amax(yFOV)), int(np.amax(yFOV) - np.amin(yFOV) + 1))
zgrid = np.linspace(int(np.amin(zFOV)), int(np.amax(zFOV)), int(np.amax(zFOV) - np.amin(zFOV) + 1))
cellVol = np.zeros((xgrid.size, ygrid.size, zgrid.size), dtype=np.float32)

# Voxel indices of each cell-body point relative to the FOV origin.
ix = xc - np.amin(xFOV)
iy = yc - np.amin(yFOV)
iz = zc - np.amin(zFOV)
# Skip points outside the bead FOV: a negative index would silently wrap to
# the opposite face of the volume, and an index past the end raises IndexError.
inside = ((ix >= 0) & (ix < cellVol.shape[0])
          & (iy >= 0) & (iy < cellVol.shape[1])
          & (iz >= 0) & (iz < cellVol.shape[2]))
if not np.all(inside):
    print("Skipping {} cell-body points outside the bead FOV".format(np.count_nonzero(~inside)))
# Sequential writes keep "last point wins" for points sharing a voxel
# (NumPy leaves the order of repeated fancy-index assignments undefined).
for i, j, k, val in zip(ix[inside], iy[inside], iz[inside], intc[inside]):
    cellVol[i, j, k] = val

# Blur every 2D slice along each of the three axes in turn with a 2D Gaussian
# (so each axis is smoothed in two of the three passes).
for i in range(cellVol.shape[0]):
    cellVol[i, :, :] = ndimage.gaussian_filter(cellVol[i, :, :], sigma=1.5)
for i in range(cellVol.shape[1]):
    cellVol[:, i, :] = ndimage.gaussian_filter(cellVol[:, i, :], sigma=1.5)
for i in range(cellVol.shape[2]):
    cellVol[:, :, i] = ndimage.gaussian_filter(cellVol[:, :, i], sigma=1.5)

# Replace bright outliers (> 4x the median raw intensity) with the median.
cellVol[cellVol > np.median(intc) * 4] = np.median(intc)
# Rescale so the nonzero voxels span [0, 255].
nonzero = cellVol[cellVol != 0]
cellVol = 255 * (cellVol - np.amin(nonzero)) / (np.amax(nonzero) - np.amin(nonzero))
# Voxels below 1 are treated as empty and hidden as NaN.
cellVol[cellVol < 1] = np.nan

# Main figure: beads drawn as glyphs colored by G' (Gre); the cell-body volume
# is added to the same figure afterwards.
f = mlab.figure(bgcolor=(0, 0, 0), size=(fpixSize, fpixSize))
print(mlab.get_engine())

# Bead positions as a scatter source carrying one RGBA color per bead.
pts = myv.tools.pipeline.scalar_scatter(x, y, z, figure=f)
pts_colormap = "turbo"
cmapD = MplColorHelper(pts_colormap, start_val=np.amin(Gre), stop_val=np.amax(Gre))
# One vectorized colormap lookup for all beads -> (n_beads, 4) uint8 RGBA.
# (Avoids np.float, removed in NumPy 1.24, and an O(n^2) nearest-value search.)
rgba = cmapD.get_rgb(Gre, 1.0)
pts.add_attribute(rgba, 'colors')  # assign the colors to each point
pts.data.point_data.set_active_scalars('colors')

g = mlab.pipeline.glyph(pts, figure=f, line_width=1.0, transparent=True, opacity=1.0, resolution=200)
g.glyph.glyph.scale_factor = 8  # set scaling for all the points
g.glyph.scale_mode = 'data_scaling_off'  # make all the points same size

print("Start Colorbar")
lut_manager = g.module_manager.scalar_lut_manager
lut_manager.show_scalar_bar = True
lut_manager.lut.range = [np.amin(Gre), np.amax(Gre)]
lut_manager.number_of_labels = 5
lut_manager.scalar_bar_representation.position = [0.2, 0.1]
lut_manager.scalar_bar_representation.position2 = [0.05, 0.2]
lut_manager.label_text_property.font_family = 'times'
lut_manager.scalar_bar.unconstrained_font_size = True
lut_manager.scalar_bar.height = 10
# Label size follows the render size (1024 px when saving, 512 px otherwise);
# an earlier note suggested 360 when saving as tiff.
lut_manager.label_text_property.font_size = 820 if saveFig else 30
lut_manager.scalar_bar.label_format = "%.0f"
lut_manager.scalar_bar.title = "G' (Pa)"
print("***Colormap properties")
g.actor.property.lighting = True

print("***end of Colormap properties")

# Add the cell body as a volume rendering on the bead figure, then set up the
# camera, outline and axes.
# indexing="ij" yields arrays shaped (len(xgrid), len(ygrid), len(zgrid)),
# matching cellVol's (x, y, z) axis order.
X, Y, Z = np.meshgrid(xgrid, ygrid, zgrid, indexing="ij")
print("***", X.shape, cellVol.shape, "***", sep="\n")
cellVolsrc = myv.tools.pipeline.scalar_field(X, Y, Z, cellVol)
# Hand-tuned display range for the Greys colormap (earlier alternatives:
# data minimum for vmin; 0.9 quantile or max of nonzero voxels for vmax).
vmin = 34
vmax = 355
# NOTE(review): this uses PyOCT's misc.ChangeVolColormap, not the
# ChangeVolColormap defined in this file -- confirm which one is intended.
vol = misc.ChangeVolColormap(
    mlab.pipeline.volume(cellVolsrc, figure=f),
    cmapName="Greys", vmin=vmin, vmax=vmax, alpha=1.0,
)
vol.update_pipeline()
# Azimuth 105 deg, elevation raised by 25 deg from the current view
# (original note: Cell II 187; Cell IV 280).
mlab.view(azimuth=105, elevation=mlab.view()[1] + 25)
outaxe = mlab.outline(vol, line_width=30 if saveFig else 1)
# Camera state after positioning; the animation starts from cam_para[0].
cam_para = mlab.view()

ax = mlab.axes(
    nb_labels=20,
    ranges=[bound(fov) for fov in (xFOV, yFOV, zFOV) for bound in (np.amin, np.amax)],
)
ax.axes.label_format = "%.0f"
# Axis labels name the data axes: plot-y shows data Z, plot-z shows data Y.
ax.axes.x_label, ax.axes.y_label, ax.axes.z_label = "X", "Z", "Y"
ax.axes.visibility = 1 if saveFig else 0
oax = mlab.orientation_axes(xlabel="x", ylabel="z", zlabel="y")
oax._text_property.use_tight_bounding_box = True
oax.update_pipeline()

cam = f.scene.camera
cam.zoom(0.85)



print("\n")
print("scene propertie ...")
# Give the scene a fresh light manager, then set up its first three lights.
f.scene.light_manager = light_manager.LightManager(renwin=f.scene)
lights = f.scene.light_manager.lights
# (elevation, azimuth, intensity) for lights 0, 1 and 2, in that order.
light_settings = ((-5, -5, 1.0), (-20, -110, 0.7), (20, 10, 0.5))
for idx, (elevation, azimuth, intensity) in enumerate(light_settings):
    lights[idx].elevation = elevation
    lights[idx].azimuth = azimuth
    lights[idx].intensity = intensity

if saveFig:
    mlab.savefig(filename=root_path + "/VolImage2.tiff", size=(fpixSize, fpixSize))
    # Context manager closes the HDF5 file even if the write raises.
    with h5py.File(root_path + "/cellVol.h5py", 'w') as fid:
        fid.create_dataset("CellVol", shape=np.shape(cellVol), dtype=np.float32, data=cellVol)
    print("\n")
    print("Done data savinng!")
    print("**********************")


# --- Turntable animation: step the camera azimuth by 3 deg from its current
# value up to 360 deg and save one PNG per step.
num_of_roll_ang = np.arange(0, 360 - cam_para[0], step=3)
# Zoom schedule (0.85 -> 1.7, then 1.6 -> 0.75, 30 values); not applied to
# the frames at present.
zoom_factor = np.stack((np.linspace(0.85, 1.7, 15), np.linspace(1.6, 0.75, 15))).flatten()

out_path = root_path


@mlab.animate(delay=10)
def anim():
    """Generator driven by mlab.animate: rotate the view and save one frame per yield."""
    start_azimuth = cam_para[0]
    for frame, offset in enumerate(num_of_roll_ang):
        azimuth = start_azimuth + offset
        print("cam azimuth at {}".format(azimuth))
        mlab.view(azimuth=azimuth)
        # Frames are named CellII_ani_<frame>.png; the index is NOT zero-padded.
        mlab.savefig(filename=os.path.join(out_path, "{}_{}{}".format("CellII_ani", frame, ".png")))
        yield


anim()

mlab.show()

#misc.WriteIntoGif(out_path,fps=5) #write into GIF if you need

